{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) Microsoft Corporation. All rights reserved.  \n",
    "Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference PyTorch Bert Model for High Performance in ONNX Runtime"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, you will learn how to load a Bert model from PyTorch, convert it to ONNX, and run inference on it with high performance using ONNX Runtime with transformer optimization. In the following sections, we use the Bert model trained on the Stanford Question Answering Dataset (SQuAD) as an example. The Bert SQuAD model is used in question-answering scenarios, where the answer to each question is a segment of text (a span) from the corresponding reading passage, or the question may be unanswerable."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Prerequisites ##\n",
    "First you need to check if the following packages exist and install them if needed. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Install the required pip packages for the interpreter that runs this Jupyter kernel.\n",
    "# Invoking pip as `{sys.executable} -m pip` targets that interpreter rather than\n",
    "# whichever `pip` executable happens to be first on PATH.\n",
    "import sys\n",
    "!{sys.executable} -m pip install wget               # used to download data files     \n",
    "# torch and torchvision are pinned; `-f` adds the PyTorch wheel page as a find-links source.\n",
    "# NOTE(review): `--user` installs into the user site-packages - confirm that location is\n",
    "# visible to this kernel (it is typically not inside a virtualenv).\n",
    "!{sys.executable} -m pip install --user torch==1.3.1 torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html\n",
    "!{sys.executable} -m pip install transformers       # used to load pytorch bert model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Pretrained Bert model ##"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We begin by downloading the data file and storing it in the specified location. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Directory that later cells write generated artifacts to (e.g. the optimized ONNX graph).\n",
    "output_dir = \"./pytorch_output\"\n",
    "# Directory holding the downloaded SQuAD file, the pretrained-model cache and the exported model.\n",
    "cache_dir = \"./pytorch_squad\"\n",
    "predict_file = os.path.join(cache_dir, \"dev-v1.1.json\")\n",
    "\n",
    "# Create the cache directory; exist_ok=True keeps the cell safe to re-run.\n",
    "os.makedirs(cache_dir, exist_ok=True)\n",
    "\n",
    "# Download the SQuAD v1.1 dev set only when it is not already present locally.\n",
    "predict_file_url = \"https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json\"\n",
    "if not os.path.exists(predict_file):\n",
    "    import wget\n",
    "    print(\"Start downloading predict file.\")\n",
    "    wget.download(predict_file_url, predict_file)\n",
    "    print(\"Predict file downloaded.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Specify some model config variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Which pretrained checkpoint to load\n",
    "model_type = \"bert\"\n",
    "model_name_or_path = \"bert-base-cased\"\n",
    "\n",
    "# Limits used when converting SQuAD examples into model features\n",
    "max_seq_length = 128\n",
    "doc_stride = 128\n",
    "max_query_length = 64\n",
    "\n",
    "# Batch sizes\n",
    "per_gpu_eval_batch_size = 1\n",
    "eval_batch_size = 1\n",
    "\n",
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the pretrained model and tokenizer, then convert the SQuAD examples into features. This step could take a few minutes. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The following code is adapted from HuggingFace transformers\n",
    "# https://github.com/huggingface/transformers/blob/master/examples/run_squad.py#L290\n",
    "\n",
    "from transformers import (WEIGHTS_NAME, BertConfig, BertForQuestionAnswering, BertTokenizer)\n",
    "from torch.utils.data import (DataLoader, SequentialSampler)\n",
    "\n",
    "# Load pretrained model and tokenizer\n",
    "config_class, model_class, tokenizer_class = (BertConfig, BertForQuestionAnswering, BertTokenizer)\n",
    "config = config_class.from_pretrained(model_name_or_path, cache_dir=cache_dir)\n",
    "# Only lowercase the text for *uncased* checkpoints. Lowercasing the input of a cased\n",
    "# checkpoint such as bert-base-cased would not match the casing of its vocabulary.\n",
    "do_lower_case = \"uncased\" in model_name_or_path\n",
    "tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=do_lower_case, cache_dir=cache_dir)\n",
    "model = model_class.from_pretrained(model_name_or_path,\n",
    "                                    from_tf=False,\n",
    "                                    config=config,\n",
    "                                    cache_dir=cache_dir)\n",
    "\n",
    "# Read the dev examples from the downloaded predict file\n",
    "from transformers.data.processors.squad import SquadV2Processor\n",
    "\n",
    "processor = SquadV2Processor()\n",
    "examples = processor.get_dev_examples(None, filename=predict_file)\n",
    "\n",
    "# Convert the examples into features plus a PyTorch dataset (return_dataset='pt')\n",
    "from transformers import squad_convert_examples_to_features\n",
    "features, dataset = squad_convert_examples_to_features(\n",
    "            examples=examples,\n",
    "            tokenizer=tokenizer,\n",
    "            max_seq_length=max_seq_length,\n",
    "            doc_stride=doc_stride,\n",
    "            max_query_length=max_query_length,\n",
    "            is_training=False,\n",
    "            return_dataset='pt'\n",
    "        )\n",
    "\n",
    "# Name the cache file after the split, the model name and the sequence length that was\n",
    "# actually used to build the features, so the name always describes its contents.\n",
    "cached_features_file = os.path.join(cache_dir, 'cached_{}_{}_{}'.format(\n",
    "        'dev',\n",
    "        list(filter(None, model_name_or_path.split('/'))).pop(),\n",
    "        str(max_seq_length))\n",
    "    )\n",
    "\n",
    "torch.save({\"features\": features, \"dataset\": dataset}, cached_features_file)\n",
    "print(\"Saved features into cached file \", cached_features_file)\n",
    "\n",
    "# Create the output directory; exist_ok=True keeps the cell safe to re-run\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "n_gpu = torch.cuda.device_count()\n",
    "# eval_batch_size = 8 * max(1, n_gpu)\n",
    "\n",
    "# Iterate over the dataset in its original order\n",
    "eval_sampler = SequentialSampler(dataset)\n",
    "eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=eval_batch_size)\n",
    "\n",
    "# multi-gpu evaluate\n",
    "if n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):\n",
    "    model = torch.nn.DataParallel(model)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Export the loaded model ##\n",
    "Once the model is loaded, we can export the loaded PyTorch model to ONNX."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the loaded PyTorch model to ONNX, using the first example of the dataset as sample input.\n",
    "print(\"***** Running evaluation *****\")\n",
    "print(\"  Num examples = \", len(dataset))\n",
    "print(\"  Batch size = \", eval_batch_size)\n",
    "\n",
    "output_model_path = './pytorch_squad/bert-base-cased-squad.onnx'\n",
    "\n",
    "# Set model to inference mode, which is required before exporting the model because some operators behave differently in \n",
    "# inference and training mode.\n",
    "model.eval()\n",
    "# Keep the model on the same device as the sample inputs built below; otherwise running\n",
    "# the model fails with a device mismatch whenever CUDA is available.\n",
    "model.to(device)\n",
    "\n",
    "# Get the first example of the dataset to run the model and export it to ONNX\n",
    "batch = tuple(t.to(device) for t in dataset[0])\n",
    "inputs = {\n",
    "    'input_ids':      batch[0].unsqueeze(0),                            # add a batch dimension of size 1. Adjust as needed.\n",
    "    'attention_mask': batch[1].unsqueeze(0),\n",
    "    'token_type_ids': batch[2].unsqueeze(0)\n",
    "}\n",
    "\n",
    "# The ONNX exporter does not support torch.nn.DataParallel wrappers, so export the wrapped module.\n",
    "model_to_export = model.module if isinstance(model, torch.nn.DataParallel) else model\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Both axes of every input and output are exported as dynamic (symbolic) dimensions.\n",
    "    symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}\n",
    "    torch.onnx.export(model_to_export,                                  # model being run\n",
    "                      (inputs['input_ids'],                             # model input (or a tuple for multiple inputs)\n",
    "                       inputs['attention_mask'], \n",
    "                       inputs['token_type_ids']), \n",
    "                      output_model_path,                                # where to save the model (can be a file or file-like object)\n",
    "                      opset_version=11,                                 # the ONNX version to export the model to\n",
    "                      do_constant_folding=True,                         # whether to execute constant folding for optimization\n",
    "                      input_names=['input_ids',                         # the model's input names\n",
    "                                   'input_mask', \n",
    "                                   'segment_ids'],\n",
    "                      output_names=['start', 'end'],                    # the model's output names\n",
    "                      dynamic_axes={'input_ids': symbolic_names,        # variable length axes\n",
    "                                    'input_mask' : symbolic_names,\n",
    "                                    'segment_ids' : symbolic_names,\n",
    "                                    'start' : symbolic_names,\n",
    "                                    'end' : symbolic_names})\n",
    "    print(\"Model exported at \", output_model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Inference the Exported Model with ONNX Runtime ##\n",
    "\n",
    "#### Install ONNX Runtime\n",
    "Install ONNX Runtime if you haven't done so already. \n",
    "\n",
    "Install `onnxruntime` to use CPU features, or `onnxruntime-gpu` to use GPU. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Pick the package flavour: the GPU build when CUDA is available, the CPU build otherwise\n",
    "ONNXRUNTIME = 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'\n",
    "\n",
    "# Install (or upgrade) the selected ONNX Runtime package for the kernel's interpreter\n",
    "!{sys.executable} -m pip install -U $ONNXRUNTIME"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to run inference on the model with ONNX Runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnxruntime as rt  \n",
    "import time\n",
    "\n",
    "sess_options = rt.SessionOptions()\n",
    "\n",
    "# Set graph optimization level to ORT_ENABLE_EXTENDED to enable bert optimization.\n",
    "sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED\n",
    "\n",
    "# To enable model serialization and store the optimized graph to desired location.\n",
    "sess_options.optimized_model_filepath = os.path.join(output_dir, \"optimized_model.onnx\")\n",
    "session = rt.InferenceSession(output_model_path, sess_options)\n",
    "\n",
    "# Convert the sample tensors to NumPy arrays *before* starting the timer, so that the\n",
    "# measurement covers only session.run and not the tensor-to-NumPy (and device-to-host) copies.\n",
    "# The keys are the input names the model was exported with.\n",
    "ort_inputs = {\n",
    "    'input_ids': inputs['input_ids'].cpu().numpy(),\n",
    "    'input_mask': inputs['attention_mask'].cpu().numpy(),\n",
    "    'segment_ids': inputs['token_type_ids'].cpu().numpy()\n",
    "}\n",
    "\n",
    "# evaluate the model; perf_counter is a monotonic, high-resolution clock meant for timing\n",
    "start = time.perf_counter()\n",
    "res = session.run(None, ort_inputs)\n",
    "end = time.perf_counter()\n",
    "print(\"ONNX Runtime inference time: \", end - start)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Measure the inference time of the original PyTorch model and verify that its outputs match those from ONNX Runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Time one forward pass of the original PyTorch model on the same sample inputs.\n",
    "# torch.no_grad() turns off autograd bookkeeping so the measurement reflects inference only.\n",
    "with torch.no_grad():\n",
    "    start = time.perf_counter()\n",
    "    outputs = model(**inputs)\n",
    "    end = time.perf_counter()\n",
    "print(\"PyTorch Inference time = \", end - start)\n",
    "\n",
    "# Compare the first two outputs of both runtimes element-wise within a small tolerance.\n",
    "print(\"***** Verifying correctness *****\")\n",
    "for i in range(2):\n",
    "    print('PyTorch and ORT matching numbers:', np.allclose(res[i], outputs[i].cpu().detach().numpy(), rtol=1e-04, atol=1e-05))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
